{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.config import Config\n",
    "config_file = \"configs/config.json\"\n",
    "cfg = Config.from_file(config_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/nvme/likunchang/miniconda3/envs/chatcap/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from models.videochat import VideoChat\n",
    "from utils.easydict import EasyDict\n",
    "import torch\n",
    "\n",
    "from transformers import StoppingCriteria, StoppingCriteriaList\n",
    "\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import numpy as np\n",
    "from decord import VideoReader, cpu\n",
    "import torchvision.transforms as T\n",
    "from models.video_transformers import (\n",
    "    GroupNormalize, GroupScale, GroupCenterCrop, \n",
    "    Stack, ToTorchFormatTensor\n",
    ")\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:18<00:00,  6.20s/it]\n"
     ]
    }
   ],
   "source": [
    "model = VideoChat(config=cfg.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model on the configured device and switch it to inference mode.\n",
    "model = model.to(torch.device(cfg.device)).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(conv):\n",
    "    \"\"\"Flatten a conversation into the prompt string for the language model.\n",
    "\n",
    "    The prompt opens with ``conv.system + conv.sep``. Every ``(role, message)``\n",
    "    pair in ``conv.messages`` then adds ``role + ': ' + message + conv.sep``;\n",
    "    a pair whose message is empty or ``None`` adds only ``role + ':'``.\n",
    "    \"\"\"\n",
    "    pieces = [conv.system + conv.sep]\n",
    "    for role, message in conv.messages:\n",
    "        if not message:\n",
    "            # Open turn: leave it unterminated so the model continues from here.\n",
    "            pieces.append(role + \":\")\n",
    "            continue\n",
    "        pieces.append(role + \": \" + message + conv.sep)\n",
    "    return \"\".join(pieces)\n",
    "\n",
    "\n",
    "def get_context_emb(conv, model, img_list):\n",
    "    \"\"\"Build the input embeddings for a prompt that mixes text and visual inputs.\n",
    "\n",
    "    The prompt from ``get_prompt(conv)`` is split at every ``<VideoHere>``\n",
    "    placeholder (or at ``<ImageHere>`` when the prompt contains no video\n",
    "    placeholder). Each text segment is tokenized and embedded, and the entries\n",
    "    of ``img_list`` are interleaved between consecutive segments.\n",
    "\n",
    "    Args:\n",
    "        conv: conversation object accepted by ``get_prompt``.\n",
    "        model: object providing ``llama_tokenizer`` and\n",
    "            ``llama_model.model.embed_tokens``.\n",
    "        img_list: visual embeddings, one per placeholder, in prompt order.\n",
    "\n",
    "    Returns:\n",
    "        The text and visual embeddings concatenated along dim 1.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the number of placeholders does not match\n",
    "            ``len(img_list)``.\n",
    "    \"\"\"\n",
    "    prompt = get_prompt(conv)\n",
    "    print(prompt)\n",
    "    if '<VideoHere>' in prompt:\n",
    "        prompt_segs = prompt.split('<VideoHere>')\n",
    "    else:\n",
    "        prompt_segs = prompt.split('<ImageHere>')\n",
    "    assert len(prompt_segs) == len(img_list) + 1, \"Unmatched numbers of image placeholders and images.\"\n",
    "    # Token ids go to the device of the visual embeddings they are concatenated\n",
    "    # with below, rather than to a hard-coded GPU. With no visual input there is\n",
    "    # nothing to take the device from, so the previous default is kept.\n",
    "    device = img_list[0].device if len(img_list) > 0 else torch.device(\"cuda:0\")\n",
    "    seg_tokens = [\n",
    "        model.llama_tokenizer(\n",
    "            seg, return_tensors=\"pt\", add_special_tokens=i == 0).to(device).input_ids\n",
    "        # only add bos to the first seg\n",
    "        for i, seg in enumerate(prompt_segs)\n",
    "    ]\n",
    "    seg_embs = [model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]\n",
    "    # Interleave: text_0, visual_0, text_1, visual_1, ..., text_last.\n",
    "    mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]\n",
    "    mixed_embs = torch.cat(mixed_embs, dim=1)\n",
    "    return mixed_embs\n",
    "\n",
    "\n",
    "def ask(text, conv):\n",
    "    \"\"\"Append a user turn holding ``text`` plus a trailing newline to ``conv.messages``.\"\"\"\n",
    "    user_turn = [conv.roles[0], text + '\\n']\n",
    "    conv.messages.append(user_turn)\n",
    "        \n",
    "\n",
    "class StoppingCriteriaSub(StoppingCriteria):\n",
    "    \"\"\"Stop generation once the sequence ends with one of the given stop sequences.\n",
    "\n",
    "    Only the first sequence of the batch (``input_ids[0]``) is inspected.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, stops=None, encounters=1):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            stops: list of 1-D token-id tensors; generation stops when the tail\n",
    "                of the sequence equals one of them. Defaults to no stops.\n",
    "            encounters: accepted for backward compatibility; not used.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # A mutable default argument would be shared by every instance,\n",
    "        # so the empty list is created per instance instead.\n",
    "        self.stops = stops if stops is not None else []\n",
    "\n",
    "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n",
    "        \"\"\"Return True when the first sequence ends with any stop sequence.\"\"\"\n",
    "        sequence = input_ids[0]\n",
    "        for stop in self.stops:\n",
    "            n = len(stop)\n",
    "            # An empty stop, or one longer than the sequence so far, cannot match;\n",
    "            # comparing tensors of different lengths would broadcast or raise.\n",
    "            if n == 0 or n > len(sequence):\n",
    "                continue\n",
    "            if torch.all((stop == sequence[-n:])).item():\n",
    "                return True\n",
    "        return False\n",
    "    \n",
    "    \n",
    "def answer(conv, model, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,\n",
    "               repetition_penalty=1.0, length_penalty=1, temperature=1.0):\n",
    "    \"\"\"Generate the assistant's reply for ``conv`` and store it in the conversation.\n",
    "\n",
    "    An empty assistant turn is appended to ``conv.messages``, the prompt is\n",
    "    embedded with ``get_context_emb`` and passed to ``model.llama_model.generate``\n",
    "    (sampling is always enabled), and the decoded reply replaces the empty turn.\n",
    "    If embedding or generation fails, the empty turn is removed again so the\n",
    "    conversation is left as it was.\n",
    "\n",
    "    Args:\n",
    "        conv: conversation with ``roles`` and ``messages``; modified in place.\n",
    "        model: object providing ``llama_model`` and ``llama_tokenizer``.\n",
    "        img_list: visual embeddings forwarded to ``get_context_emb``.\n",
    "        max_new_tokens, num_beams, min_length, top_p, repetition_penalty,\n",
    "        length_penalty, temperature: forwarded unchanged to ``generate``.\n",
    "\n",
    "    Returns:\n",
    "        ``(output_text, output_token)``: the cleaned reply string and the\n",
    "        generated token ids as a NumPy array.\n",
    "    \"\"\"\n",
    "    conv.messages.append([conv.roles[1], None])\n",
    "    try:\n",
    "        # Inference only: keep the text embeddings from holding an autograd graph.\n",
    "        with torch.no_grad():\n",
    "            embs = get_context_emb(conv, model, img_list)\n",
    "        # The stop ids are compared with the generated ids, so they are placed on\n",
    "        # the device of the inputs instead of a hard-coded GPU.\n",
    "        # NOTE(review): assumes generate() keeps ids on the device of inputs_embeds.\n",
    "        stop_words_ids = [\n",
    "            torch.tensor([835]).to(embs.device),\n",
    "            torch.tensor([2277, 29937]).to(embs.device)]  # '###' can be encoded in two different ways.\n",
    "        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])\n",
    "        outputs = model.llama_model.generate(\n",
    "            inputs_embeds=embs,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            stopping_criteria=stopping_criteria,\n",
    "            num_beams=num_beams,\n",
    "            do_sample=True,\n",
    "            min_length=min_length,\n",
    "            top_p=top_p,\n",
    "            repetition_penalty=repetition_penalty,\n",
    "            length_penalty=length_penalty,\n",
    "            temperature=temperature,\n",
    "        )\n",
    "    except BaseException:\n",
    "        # Also covers a kernel interrupt: drop the empty assistant turn so a retry\n",
    "        # does not stack placeholders in the conversation.\n",
    "        conv.messages.pop()\n",
    "        raise\n",
    "    output_token = outputs[0]\n",
    "    # The length checks avoid indexing an empty tensor after a leading token was stripped.\n",
    "    if len(output_token) > 0 and output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it\n",
    "        output_token = output_token[1:]\n",
    "    if len(output_token) > 0 and output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it\n",
    "        output_token = output_token[1:]\n",
    "    # NOTE(review): `add_special_tokens` looks like an encode-side option; confirm\n",
    "    # whether `skip_special_tokens` was intended for decode().\n",
    "    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)\n",
    "    output_text = output_text.split('###')[0]  # remove the stop sign '###'\n",
    "    output_text = output_text.split('Assistant:')[-1].strip()\n",
    "    conv.messages[-1][1] = output_text\n",
    "    return output_text, output_token.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_index(num_frames, num_segments):\n",
    "    \"\"\"Return ``num_segments`` frame indices spread evenly over ``num_frames`` frames.\n",
    "\n",
    "    The step between samples is ``(num_frames - 1) / num_segments`` and the first\n",
    "    sample sits half a step into the video. The result is a NumPy array.\n",
    "    \"\"\"\n",
    "    step = float(num_frames - 1) / num_segments\n",
    "    first = int(step / 2)\n",
    "    picks = []\n",
    "    for k in range(num_segments):\n",
    "        picks.append(first + int(np.round(step * k)))\n",
    "    return np.array(picks)\n",
    "\n",
    "\n",
    "def load_video(video_path, num_segments=8, return_msg=False, resolution=224):\n",
    "    \"\"\"Sample frames from a video and preprocess them for the visual encoder.\n",
    "\n",
    "    Args:\n",
    "        video_path: path of the video, opened with decord's ``VideoReader``.\n",
    "        num_segments: number of frames to sample; indices come from ``get_index``.\n",
    "        return_msg: when True, also return a sentence listing the sampled timestamps.\n",
    "        resolution: side length used for both the resize and the center crop.\n",
    "            Defaults to 224, the value that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        The transformed frames, or ``(frames, msg)`` when ``return_msg`` is True.\n",
    "    \"\"\"\n",
    "    vr = VideoReader(video_path, ctx=cpu(0))\n",
    "    num_frames = len(vr)\n",
    "    frame_indices = get_index(num_frames, num_segments)\n",
    "\n",
    "    # transform\n",
    "    crop_size = int(resolution)\n",
    "    scale_size = int(resolution)\n",
    "    input_mean = [0.48145466, 0.4578275, 0.40821073]\n",
    "    input_std = [0.26862954, 0.26130258, 0.27577711]\n",
    "\n",
    "    transform = T.Compose([\n",
    "        GroupScale(scale_size, interpolation=InterpolationMode.BICUBIC),\n",
    "        GroupCenterCrop(crop_size),\n",
    "        Stack(),\n",
    "        ToTorchFormatTensor(),\n",
    "        GroupNormalize(input_mean, input_std)\n",
    "    ])\n",
    "\n",
    "    # Decode only the sampled frames and hand them to the transform as PIL images.\n",
    "    images_group = [\n",
    "        Image.fromarray(vr[frame_index].asnumpy()) for frame_index in frame_indices\n",
    "    ]\n",
    "    torch_imgs = transform(images_group)\n",
    "    if return_msg:\n",
    "        fps = float(vr.get_avg_fps())\n",
    "        sec = \", \".join([str(round(f / fps, 1)) for f in frame_indices])\n",
    "        # The message carries no surrounding whitespace; the caller adds it when\n",
    "        # inserting the message into a prompt.\n",
    "        msg = f\"The video contains {len(frame_indices)} frames sampled at {sec} seconds.\"\n",
    "        return torch_imgs, msg\n",
    "    return torch_imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The video contains 8 frames sampled at 0.6, 1.8, 3.0, 4.2, 5.5, 6.7, 7.9, 9.1 seconds.\n"
     ]
    }
   ],
   "source": [
    "# vid_path = \"./example/yoga.mp4\"\n",
    "vid_path = \"./example/jesse_dance.mp4\"\n",
    "vid, msg = load_video(vid_path, num_segments=8, return_msg=True)\n",
    "print(msg)\n",
    "\n",
    "# vid is unpacked as (T*3, H, W); regroup it into one clip of shape (1, T, 3, H, W).\n",
    "TC, H, W = vid.shape\n",
    "# Send the clip to the device the model was moved to, not a hard-coded \"cuda:0\".\n",
    "video = vid.reshape(1, TC//3, 3, H, W).to(torch.device(cfg.device))\n",
    "\n",
    "# Inference only, so skip building an autograd graph for the visual embedding.\n",
    "with torch.no_grad():\n",
    "    image_emb, _ = model.encode_img(video)\n",
    "img_list = [image_emb]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "###Human: <Video><VideoHere></Video> The video contains 8 frames sampled at 0.6, 1.8, 3.0, 4.2, 5.5, 6.7, 7.9, 9.1 seconds.\n",
      "###Human: Explain why it is funny?\n",
      "###Assistant:\n",
      "This is a funny video of a man dancing in a kitchen with his handy family of paper people. The man is clearly having fun, and the fact that the paper people are also moving and dancing alongside him adds to the humor of the video. The man's exuberance and enthusiasm while dancing with the paper people makes the video particularly entertaining. It is a lighthearted and whimsical moment that showcases a man's playful side. The use of paper people adds a touch of creativity and imagination to the video. Overall, it is a funny and heartwarming video.\n"
     ]
    }
   ],
   "source": [
    "# Start a fresh conversation with an empty system prompt. Alternative system prompt:\n",
    "#     \"You are an AI assistant. A human gives an image or a video and asks some questions. You should give helpful, detailed, and polite answers.\\n\"\n",
    "chat = EasyDict(dict(\n",
    "    system=\"\",\n",
    "    roles=(\"Human\", \"Assistant\"),\n",
    "    messages=[],\n",
    "    sep=\"###\",\n",
    "))\n",
    "\n",
    "# First turn: the video placeholder followed by the frame-sampling description.\n",
    "chat.messages.append([chat.roles[0], f\"<Video><VideoHere></Video> {msg}\\n\"])\n",
    "\n",
    "# Other questions to try in place of the one asked below:\n",
    "#   \"Describe the video in detail.\"\n",
    "#   \"Is she safe to doing something in the video?\"\n",
    "#   \"Where is she?\"\n",
    "#   \"Who are you?\"\n",
    "#   \"Are you an assistant?\"\n",
    "#   \"How can I learn to do as the video?\"\n",
    "#   \"Tell me what she did at the fourth second\"\n",
    "#   \"Can you provide me some urls to learn to do as the video?\"\n",
    "#   \"List the president of America.\"\n",
    "#   \"What do you feel from the video?\"\n",
    "#   \"Do you think it is funny? Why?\"\n",
    "#   \"Explain why the video is ridiculous?\"\n",
    "ask(text=\"Explain why it is funny?\", conv=chat)\n",
    "\n",
    "# answer() returns (reply text, token ids); only the text is kept.\n",
    "llm_message = answer(chat, model, img_list, max_new_tokens=1000)[0]\n",
    "print(llm_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# img_path = \"./example/bear.jpg\"\n",
    "# img_path = \"./example/cxk.png\"\n",
    "img_path = \"./example/dog.png\"\n",
    "img = Image.open(img_path).convert('RGB')\n",
    "\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(\n",
    "            (224, 224), interpolation=InterpolationMode.BICUBIC\n",
    "        ),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Add two leading singleton dimensions and send the tensor to the device the\n",
    "# model was moved to, instead of whichever CUDA device is the current default.\n",
    "img = transform(img).unsqueeze(0).unsqueeze(0).to(torch.device(cfg.device))\n",
    "\n",
    "# Inference only, so skip building an autograd graph for the visual embedding.\n",
    "with torch.no_grad():\n",
    "    image_emb, _ = model.encode_img(img)\n",
    "img_list = [image_emb]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "###Human: <Image><ImageHere></Image>\n",
      "###Human: Explain why the image is funny?\n",
      "###Assistant:\n",
      "The image is funny because the puppy is sleeping on a wooden floor with the words, \"just monday\" written underneath it. This implies that the puppy is lazy and is using this as an excuse for not having to do anything else on Monday. The words, \"just monday,\" are written in a bold font, emphasizing the puppy 's laziness and disdain for Monday. The puppy is lying down on the floor, which adds to the humor and implies that it doesn't want to move or do anything on this day of the week. The photo is taken from a slightly upward angle, which adds to the visual impact and creates a comical effect. The caption is written in a playful, humorous font that reinforces the joke and encourages viewers to have a good laugh. The combination of the puppy, the words, and the photo make this image funny and entertaining.\n"
     ]
    }
   ],
   "source": [
    "# Start a fresh conversation with an empty system prompt. Alternative system prompt:\n",
    "#     \"You are an AI assistant. A human gives an Image and asks questions. You should give helpful, detailed, and polite answers.\\n\"\n",
    "chat = EasyDict(dict(\n",
    "    system=\"\",\n",
    "    roles=(\"Human\", \"Assistant\"),\n",
    "    messages=[],\n",
    "    sep=\"###\",\n",
    "))\n",
    "\n",
    "# First turn: the image placeholder on its own line.\n",
    "chat.messages.append([chat.roles[0], \"<Image><ImageHere></Image>\\n\"])\n",
    "\n",
    "# Other questions to try in place of the one asked below:\n",
    "#   \"Describe the image in detail.\"\n",
    "#   \"What's in the image?\"\n",
    "#   \"Who are you?\"\n",
    "#   \"Are you an assistant?\"\n",
    "#   \"Where can I find the animal?\"\n",
    "#   \"Give me some urls to find the animal\"\n",
    "#   \"Give me some urls to learn to do it\"\n",
    "#   \"Is it funny? why?\"\n",
    "#   \"What do you feel from the image?\"\n",
    "#   \"Do you think the image is funny?\"\n",
    "#   \"What is the text in the image?\"\n",
    "ask(text=\"Explain why the image is funny?\", conv=chat)\n",
    "\n",
    "# answer() returns (reply text, token ids); only the text is kept.\n",
    "llm_message = answer(chat, model, img_list, max_new_tokens=1000)[0]\n",
    "print(llm_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
